-- -- -- -- 卷积的矩阵乘法实现：
-- -- -- --   https://www.cnblogs.com/shine-lee/p/10775831.html
-- -- -- --   https://blog.csdn.net/qq_40263477/article/details/104979609
-- -- 卷积的 fft 实现
-- --   https://blog.csdn.net/qq_37527608/article/details/120922348
-- -- 卷积与矩阵乘法算法调优：
-- --      -- --  https://blog.csdn.net/lizhengx/article/details/83246833
-- --      -- -- 卷积求导
-- --      -- --  https://blog.csdn.net/wangjianyu0115/article/details/62222301
-- --      
-- --      -- set search_path to sm_sc;
-- drop function if exists sm_sc.fv_d_conv_2d_dloss_dindepdt_2(float[], float[], int[], int[2], int[4], float);
-- create or replace function sm_sc.fv_d_conv_2d_dloss_dindepdt_2
-- (
--   i_array              float[]                                     ,  -- 原矩阵(卷积第一目参数)
--   i_dloss_ddepdt       float[]                                     ,  -- 即已求出的损失函数对 y 的导数矩阵
--   i_window_len         int[]                                               ,  -- 卷积核窗口矩阵规格，如果是三维、四维，那么长度为三维或四维
--   i_stride             int[2]              default  array[1, 1]             ,  -- 纵向与横向步长
--   i_padding            int[4]              default  array[0, 0, 0, 0]       ,  -- 上下左右补齐行数/列数
--   i_padding_value      float      default  0.0                        -- 补齐填充元素值
-- )
-- returns float[]
-- as
-- $$
-- declare 
--   v_array_len_heigh        int   :=   array_length(i_array, array_ndims(i_array) - 1) ;
--   v_array_len_width        int   :=   array_length(i_array, array_ndims(i_array))     ;
--   v_dloss_ddepdt_len_heigh int   :=   array_length(i_dloss_ddepdt, array_ndims(i_dloss_ddepdt) - 1) ;
--   v_dloss_ddepdt_len_width int   :=   array_length(i_dloss_ddepdt, array_ndims(i_dloss_ddepdt)) ;
--   v_array_len_heigh_ex     int   :=   coalesce(i_padding[1], 0) + v_array_len_heigh + coalesce(i_padding[2], 0);     --   新背景矩阵高
--   v_array_len_width_ex     int   :=   coalesce(i_padding[3], 0) + v_array_len_width + coalesce(i_padding[4], 0);     --   新背景矩阵宽
--   v_window_len             int[2] :=   i_window_len[array_length(i_window_len, 1) - 1 : array_length(i_window_len, 1)];
--   v_ret                    float[];
-- begin
--   -- 审计
--   if current_setting('pg4ml._v_is_debug_check', true) = '1'
--   then
--     -- 审计二维长度
--     if array_ndims(i_array) not in (2, 3, 4) or array_ndims(i_dloss_ddepdt) not in (2, 3, 4) or array_ndims(i_dloss_ddepdt) <> array_ndims(i_array)
--     then 
--       raise exception 'unsupport ndims of i_array and i_dloss_ddepdt.';
--     elsif array_ndims(i_dloss_ddepdt) = 3 and array_length(i_dloss_ddepdt, 1) <> array_length(i_array, 1)
--       or (array_ndims(i_dloss_ddepdt) = 4 and (array_length(i_dloss_ddepdt, 1) <> array_length(i_array, 1) or array_length(i_dloss_ddepdt, 2) <> array_length(i_array, 2)))
--     then 
--       raise exception 'unmatch length at 1d of 3d or 1d / 2d of 4d';
--     elsif (v_array_len_heigh_ex - v_window_len[1]) % i_stride[1] <> 0
--     then 
--       raise exception 'imperfect window at heigh.';
--     elsif (v_array_len_width_ex - v_window_len[2]) % i_stride[2] <> 0
--     then 
--       raise exception 'imperfect window at width.';
--     elsif v_dloss_ddepdt_len_heigh <> (v_array_len_heigh_ex - v_window_len[1]) / i_stride[1] + 1
--       or v_dloss_ddepdt_len_width <> (v_array_len_width_ex - v_window_len[2]) / i_stride[2] + 1
--     then
--       raise exception 'unmatch length between y and dloss/dy.';
--     end if;
--   end if;
--   
--   if array_ndims(i_dloss_ddepdt) = 2
--   then 
--     i_array := 
--       sm_sc.fv_augmented
--       (
--         i_array, 
--         array[-i_padding[1] + 1, -i_padding[3] + 1], 
--         array[v_array_len_heigh + i_padding[2], v_array_len_width + i_padding[4]], 
--         i_padding_value
--       );
--     return 
--       |~~|
--       (
--         select 
--           sm_sc.fa_mx_sum
--           (
--             i_array[col_a_y : col_a_y + v_window_len[1] - 1][col_a_x : col_a_x + v_window_len[2] - 1] 
--               *` i_dloss_ddepdt[(col_a_y - 1) / i_stride[1] + 1][(col_a_x - 1) / i_stride[2] + 1]
--           )
--         from generate_series(1, v_array_len_heigh_ex - v_window_len[1] + i_stride[1], i_stride[1]) tb_a_y(col_a_y)
--           , generate_series(1, v_array_len_width_ex - v_window_len[2] + i_stride[2], i_stride[2]) tb_a_x(col_a_x)
--       );
--   
--   elsif array_ndims(i_dloss_ddepdt) = 3
--   then 
--     v_ret := 
--     (
--       select 
--         array_agg -- sm_sc.fa_mx_sum 
--         (
--           sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--           (
--             sm_sc.fv_mx_slice_3d_2_2d
--             (
--               i_array[a_cur_y : a_cur_y]
--             , 1
--             )
--           , sm_sc.fv_mx_slice_3d_2_2d
--             (
--               i_dloss_ddepdt[a_cur_y : a_cur_y]
--             , 1
--             )
--           , v_window_len     
--           , i_stride           
--           , i_padding       
--           , i_padding_value
--           )
--           order by a_cur_y
--         )
--       from generate_series(1, array_length(i_dloss_ddepdt, 1)) tb_a_cur_y(a_cur_y)
--     );
--     if array_length(i_window_len, 1) = 2
--     then 
--       v_ret := 
--         sm_sc.fv_mx_slice_3d_2_2d
--         (
--           sm_sc.fv_aggr_slice_sum_py(v_ret, array[array_length(v_ret, 1), 1, 1])
--         , 1
--         );
--     end if;
--     return v_ret;
--     
--   elsif array_ndims(i_dloss_ddepdt) = 4
--   then 
--     v_ret := 
--     (
--       with 
--       cte_agg_x as 
--       (
--         select 
--           a_cur_y, 
--           array_agg    -- sm_sc.fa_mx_sum 
--           (
--             sm_sc.fv_d_conv_2d_dloss_dindepdt_2
--             (
--               sm_sc.fv_mx_slice_4d_2_2d
--               (
--                 i_array[a_cur_y : a_cur_y][a_cur_x : a_cur_x][ : ][ : ]
--               , array[1, 2]
--               , array[1, 1]
--               )
--             , sm_sc.fv_mx_slice_4d_2_2d
--               (
--                 i_dloss_ddepdt[a_cur_y : a_cur_y][a_cur_x : a_cur_x][ : ][ : ]
--               , array[1, 2]
--               , array[1, 1]
--               )
--             , v_window_len     
--             , i_stride           
--             , i_padding       
--             , i_padding_value
--             )
--             order by a_cur_x
--           ) as a_agg_x
--         from generate_series(1, array_length(i_dloss_ddepdt, 1)) tb_a_cur_y(a_cur_y)
--           , generate_series(1, array_length(i_dloss_ddepdt, 2)) tb_a_cur_x(a_cur_x)
--         group by a_cur_y
--       )
--       select 
--         array_agg(a_agg_x order by a_cur_y)
--       from cte_agg_x
--     );
--     if array_length(i_window_len, 1) = 2
--     then 
--       v_ret := 
--         sm_sc.fv_mx_slice_4d_2_2d
--         (
--           sm_sc.fv_aggr_slice_sum_py(v_ret, array[array_length(v_ret, 1), array_length(v_ret, 2), 1, 1])
--         , array[1, 2]
--         , array[1, 1]
--         );
--     end if;
--     return v_ret;
--   end if;
-- end
-- $$
-- language plpgsql stable
-- parallel safe
-- cost 100;
-- -- -- set search_path to sm_sc;
-- -- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
-- --   (
-- --     array[[1.0,2.0,3.0,4.0,5.0,6.0,7.0]
-- --         , [10.0,20.0,30.0,40.0,50.0,60.0,70.0]
-- --         , [100.0,200.0,300.0,400.0,500.0,600.0,700.0]
-- --         , [-1.0,-2.0,-3.0,-4.0,-5.0,-6.0,-7.0]
-- --         , [-10.0,-20.0,-30.0,-40.0,-50.0,-60.0,-70.0]
-- --          ] :: float[]
-- --    , array[[1.1, 1.1, 1.1], [1.1, 1.1, 1.1]]
-- --    , array[3, 3]
-- --    , array[2, 2]
-- --   );
-- 
-- -- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
-- --   (
-- --     array[[1,2,3,4,5,6]
-- --         , [10,20,30,40,50,60]
-- --         , [100,200,300,400,500,600]
-- --         , [-1,-2,-3,-4,-5,-6]
-- --         , [-10,-20,-30,-40,-50,-60]
-- --          ] :: float[]
-- --    , array[[1.1, 1.1, 1.1], [1.1, 2.1, 1.1], [1.1, 1.1, 1.1]]
-- --    , array[3, 3]
-- --    , array[2, 2]
-- --    , array[1, 1, 1, 0]
-- --    , 0
-- --   );
-- 
-- -- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
-- --   (
-- --     array
-- --       [
-- --         [
-- --           [1,2,3,4,5,6]
-- --         , [10,20,30,40,50,60]
-- --         , [100,200,300,400,500,600]
-- --         , [-1,-2,-3,-4,-5,-6]
-- --         , [-10,-20,-30,-40,-50,-60]
-- --         ]
-- --       , [
-- --           [-1,2,-3,4,5,6]
-- --         , [10,-20,30,40,50,-60]
-- --         , [100,200,-300,400,500,600]
-- --         , [-1,2,-3,-4,-5,-6]
-- --         , [-10,-20,30,-40,50,-60]
-- --         ]
-- --       ]
-- --    , array[[[1.1, 1.1, -1.1], [1.1, -2.1, 1.1], [-1.1, 1.1, 1.1]],[[-1.1, 1.1, 1.1], [1.1, -2.1, 1.1], [1.1, 1.1, -1.1]]]
-- --    , array[3, 3]     --  array[2, 3, 3]
-- --    , array[2, 2]
-- --    , array[1, 1, 1, 0]
-- --    , 0
-- --   ) :: decimal[] ~=` 3;
-- 
-- -- select sm_sc.fv_d_conv_2d_dloss_dindepdt_2
-- --   (
-- --    array
-- --    [
-- --      [
-- --        [
-- --          [1,2,3,4,5,6]
-- --        , [10,20,30,40,50,60]
-- --        , [100,200,300,400,500,600]
-- --        , [-1,-2,-3,-4,-5,-6]
-- --        , [-10,-20,-30,-40,-50,-60]
-- --        ]
-- --      , [
-- --          [-1,2,-3,4,5,6]
-- --        , [10,-20,30,40,50,-60]
-- --        , [100,200,-300,400,500,600]
-- --        , [-1,2,-3,-4,-5,-6]
-- --        , [-10,-20,30,-40,50,-60]
-- --        ]
-- --      ]
-- --    , [
-- --        [
-- --          [1,2,3,-4,5,6]
-- --        , [10,20,30,40,50,60]
-- --        , [100,-200,300,400,500,600]
-- --        , [-1,2,-3,-4,5,-6]
-- --        , [10,20,-30,-40,-50,60]
-- --        ]
-- --      , [
-- --          [1,2,-3,-4,5,6]
-- --        , [10,-20,-30,40,50,60]
-- --        , [100,200,-300,-400,-500,600]
-- --        , [-1,2,-3,-4,5,-6]
-- --        , [-10,-20,30,-40,50,-60]
-- --        ]
-- --      ]
-- --    ]  
-- --    , array
-- --      [
-- --        [[[1.1, 1.1, -1.1], [-1.1, -2.1, -1.1], [-1.1, 1.1, 1.1]]
-- --        ,[[-1.1, 1.1, 1.1], [1.1, -2.1, 1.1], [1.1, 1.1, -1.1]]]
-- --      , [[[1.1, 1.1, -1.1], [1.1, -2.1, 1.1], [-1.1, 1.1, 1.1]]
-- --        ,[[-1.1, 1.1, 1.1], [-1.1, -2.1, -1.1], [1.1, 1.1, -1.1]]]
-- --      ]
-- --    , array[3, 3]    -- array[2, 2, 3, 3]
-- --    , array[2, 2]
-- --    , array[1, 1, 1, 0]
-- --    , 0
-- --   ) :: decimal[] ~=` 3;